import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np


from cbml_benchmark.modeling import registry

@registry.BACKBONES.register('googlenet')
class GoogLeNet(nn.Module):
    """GoogLeNet trunk followed by a centroid-based residual aggregation head.

    The convolutional trunk (``conv1`` .. ``inception5b``) produces a feature
    map with 1024 channels (the sum of the four branch widths of
    ``inception5b``).  ``forward`` then aggregates that map against the
    learnable ``centroids3`` parameter (``num_clusters[3]`` x 1024) via
    :meth:`ra3`.

    The classifier pieces (``aux1``, ``aux2``, ``avgpool``, ``dropout``,
    ``fc``) are still constructed so that the module's ``state_dict`` keys do
    not change, but ``forward`` in this class does not call them.  Likewise
    ``transform_input`` is stored but not consulted by ``forward``.
    """
    __constants__ = ['aux_logits', 'transform_input']

    def __init__(self):
        super(GoogLeNet, self).__init__()

        # Fixed configuration of this backbone.
        num_classes = 1000
        dropout_rate = 0.2
        conv_block = BasicConv2d
        inception_block = Inception
        inception_aux_block = InceptionAux

        self.aux_logits = True
        self.transform_input = True

        # NOTE: the construction order below is kept stable on purpose; it
        # fixes both the state_dict key order and the order in which random
        # initialisation consumes the RNG.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = inception_aux_block(512, num_classes)
            self.aux2 = inception_aux_block(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(1024, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # Side length of the output grid per index, and the matching number
        # of clusters (side * side).
        self.num = [64, 32, 1, 8]
        self.num_clusters = [side * side for side in self.num]

        # Initial centroid values; kept as a plain tensor attribute.
        self.val3 = torch.randn(self.num_clusters[3], 1024)
        # Register the centroids exactly once.  Calling this through
        # ``self.apply`` would rebuild the same parameter for every
        # sub-module.
        self._init_centroids()

    def _init_centroids(self, m=None):
        """(Re)create ``centroids3`` as a learnable copy of ``val3``.

        Args:
            m: Ignored.  Accepted only so the method can still be passed to
                ``nn.Module.apply``.

        The parameter is no longer forced onto ``"cuda"``: on first creation
        it lives wherever ``val3`` lives (CPU) and is moved with the rest of
        the model by ``.to(...)`` / ``.cuda()``.  When re-initialising, the
        device of the existing parameter is kept.
        """
        existing = getattr(self, "centroids3", None)
        device = existing.device if existing is not None else self.val3.device
        # clone() so that training updates never write into ``val3``.
        self.centroids3 = nn.Parameter(self.val3.clone().to(device))

    def ra3(self, x, index):
        """Aggregate a feature map into per-centroid weighted residuals.

        For every centroid ``k`` (row of ``centroids3``, L2-normalised) and
        every sample ``n`` this computes::

            w[n, k, l]  = softplus(<x[n, :, l], c_k> / sqrt(K))
            ra[n, k, :] = sum_l w[n, k, l] * (x[n, :, l] - c_k) / C

        then L2-normalises ``ra`` over the channel axis.

        Args:
            x: 4-D tensor ``(N, C, H, W)``; ``C`` must equal the number of
                columns of ``centroids3``.
            index: Position into ``self.num`` / ``self.num_clusters`` that
                selects the cluster count ``K`` and the output grid side.

        Returns:
            Tensor of shape ``(N, C, side, side)`` with
            ``side = self.num[index]`` (so ``side * side == K``).

        Raises:
            ValueError: If ``K`` exceeds the number of available centroids.
        """
        N, C1, _, _ = x.shape
        k = self.num_clusters[index]
        side = self.num[index]
        if k > self.centroids3.size(0):
            raise ValueError(
                "index %r requests %d clusters but only %d centroids exist"
                % (index, k, self.centroids3.size(0))
            )

        x_flatten = x.reshape(N, C1, -1)                           # (N, C, L)
        centroids = F.normalize(self.centroids3[:k], p=2, dim=1)   # (K, C)

        # Scaled similarity between every spatial position and every centroid.
        sim = torch.matmul(centroids, x_flatten) / np.sqrt(k)      # (N, K, L)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        sim = F.softplus(sim)

        # sum_l w * (x - c) == (sum_l w * x) - c * (sum_l w); one batched
        # matmul replaces the per-centroid Python loop.
        weighted_x = torch.bmm(sim, x_flatten.transpose(1, 2))     # (N, K, C)
        weight_total = sim.sum(dim=-1, keepdim=True)               # (N, K, 1)
        ra = (weighted_x - weight_total * centroids.unsqueeze(0)) / C1

        ra = F.normalize(ra, p=2, dim=2)
        return ra.permute(0, 2, 1).reshape(N, C1, side, side)

    def extract_features(self, inputs):
        """Run the convolutional trunk and return the ``inception5b`` output."""
        x = self.conv1(inputs)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)
        return x

    def forward(self, inputs):
        """Extract trunk features and aggregate them against ``centroids3``.

        Returns:
            A 4-tuple ``(x, x, centroids3, centroids3)`` where ``x`` is the
            contiguous output of ``ra3(features, 3)``.  The aggregated tensor
            and the centroid parameter are each returned twice.
        """
        x = self.extract_features(inputs)
        x = self.ra3(x, 3).contiguous()
        return x, x, self.centroids3, self.centroids3

    def load_param(self, model_path):
        """Copy tensors from a checkpoint file into this model in place.

        Entries whose key contains ``'last_linear'`` are skipped.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``;
                the loaded object is iterated as a mapping of name -> tensor.

        Raises:
            KeyError: If the checkpoint holds a key this model does not have.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # hosts; copy_ below moves the values to each parameter's own device.
        param_dict = torch.load(model_path, map_location="cpu")
        own_state = self.state_dict()
        for name, value in param_dict.items():
            if 'last_linear' in name:
                continue
            own_state[name].copy_(value)


class Inception(nn.Module):
    """Four parallel branches whose outputs are concatenated along dim 1.

    Args:
        in_channels: Channels of the input feature map.
        ch1x1: Output channels of the 1x1 branch.
        ch3x3red: Channels of the 1x1 reduction in front of ``branch2``.
        ch3x3: Output channels of ``branch2`` (3x3 conv, padding 1).
        ch5x5red: Channels of the 1x1 reduction in front of ``branch3``.
        ch5x5: Output channels of ``branch3``.  Note that this branch is
            built with ``kernel_size=3`` here, despite the parameter name.
        pool_proj: Output channels of the 1x1 conv after the max-pool.
        conv_block: Callable ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``
            used for every convolution; ``BasicConv2d`` when ``None``.

    The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.
    """
    __constants__ = ['branch2', 'branch3', 'branch4']

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
                 conv_block=None):
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: plain 1x1 projection.
        self.branch1 = make_conv(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction, then 3x3.
        reduce_3x3 = make_conv(in_channels, ch3x3red, kernel_size=1)
        expand_3x3 = make_conv(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, expand_3x3)

        # Branch 3: 1x1 reduction, then another 3x3.
        reduce_5x5 = make_conv(in_channels, ch5x5red, kernel_size=1)
        expand_5x5 = make_conv(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_5x5, expand_5x5)

        # Branch 4: size-preserving max-pool, then 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        project = make_conv(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, project)

    def _forward(self, x):
        """Return the four branch outputs as a list, in branch order."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return [branch(x) for branch in branches]

    def forward(self, x):
        """Apply all branches to ``x`` and concatenate along the channel axis."""
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head: pool to 4x4, 1x1 conv, two linear layers.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Size of the output logits vector.
        conv_block: Callable ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``
            used for the 1x1 convolution; ``BasicConv2d`` when ``None``.

    ``fc1`` expects 2048 input features, i.e. the 128 conv channels times
    the 4x4 pooled grid.
    """

    def __init__(self, in_channels, num_classes, conv_block=None):
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv = make_conv(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map ``(N, in_channels, H, W)`` features to ``(N, num_classes)`` logits."""
        pooled = F.adaptive_avg_pool2d(x, (4, 4))        # N x in_channels x 4 x 4
        projected = self.conv(pooled)                    # N x 128 x 4 x 4
        flat = torch.flatten(projected, 1)               # N x 2048
        hidden = F.relu(self.fc1(flat), inplace=True)    # N x 1024
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, 0.7, training=self.training)
        return self.fc2(hidden)                          # N x num_classes


class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``; the ReLU is applied in place."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
